import pandas as pd
import matplotlib.pyplot as plt

pd.set_option('display.max_colwidth', None)
# Suffix selecting which manually labelled cluster files are read (presumably
# the number of clusters -- confirm against the data pipeline).
c = 1000

# Load narratives, manual cluster labels and the ranked entity list.
# The narratives path previously called .format(c) without a placeholder, so
# it stayed at 1000 even if c changed; it now follows c like the labels file.
df = pd.read_csv('../data/gpo_final_data/narratives_complete_with_metadata_manual_labels_{0}_no_frequency_filter.csv'.format(c))
labels = pd.read_csv('../data/metadata/gpo_manual_cluster_labels/all_clusters_manually_labeled_{0}.csv'.format(c))
labels = labels[['ARG', 'cluster_elements', 'cluster_label_manual']]
# NOTE: unpickling can execute arbitrary code; only load this file from a
# trusted, locally produced model directory.
entities = pd.read_pickle('../models/entities_sorted.pk')

# Pool the agent (ARGO) and patient (ARG1) columns into a single 'ARG' column
# and count how often each argument string occurs across both roles.
# NOTE(review): the column is spelled 'ARGO' (letter O) throughout this
# script; confirm it matches the CSV header.
# pd.concat replaces DataFrame.append (removed in pandas 2.0); renaming
# without inplace=True avoids mutating a column-slice copy.
df = pd.concat(
    [
        df[['ARGO']].rename(columns={'ARGO': 'ARG'}),
        df[['ARG1']].rename(columns={'ARG1': 'ARG'}),
    ]
)
# Missing arguments are counted as the empty string rather than dropped.
df = df.fillna('')
df = df.groupby(['ARG']).size().reset_index()
df.columns = ['ARG', 'count']

df = df.sort_values(["count"], ascending=False)

# Keep the first field of each ranked entry, limited to the top 1000.
entities = [entity[0] for entity in entities][0:1000]

df = df.merge(labels, on='ARG')
# Re-sort stably after the merge so the frequency order does not depend on
# merge row-order guarantees. A stable sort keeps the pre-merge tie order.
df = df.sort_values(["count"], ascending=False, kind='mergesort')

df = df[['cluster_label_manual', 'cluster_elements', 'count']]
df.columns = ['Label', 'Most frequent phrases', 'Frequency']

# Write two LaTeX tables from the first 20 rows of each subset of df
# (df is expected to be sorted by frequency upstream):
#   Tables_C_1_2.tex -- clusters whose label is NOT in `entities`
#   Tables_C_3_4.tex -- clusters whose label IS in `entities` (named entities)
is_named_entity = df['Label'].isin(entities)
for mask, tex_path in (
    (~is_named_entity, '../tables/Tables_C_1_2.tex'),
    (is_named_entity, '../tables/Tables_C_3_4.tex'),
):
    temp = df.loc[mask].iloc[:20]
    temp.to_latex(tex_path, index=False)

# Plot entities as agents or patients: histogram, over arguments, of the share
# of occurrences in the agent column (ARGO) versus the patient column (ARG1).
df = pd.read_csv('../data/gpo_final_data/narratives_complete_with_metadata_manual_labels_{0}_no_frequency_filter.csv'.format(c), usecols=['ARGO', 'ARG1'])

# rename_axis + reset_index(name=...) gives columns ['arg', <role>] on both
# pandas 1.x and 2.x. Plain reset_index() yields ['index', <role>] on 1.x but
# [<role>, 'count'] on 2.x, which broke the rename/merge on 'arg'.
arg0counts = df['ARGO'].value_counts().rename_axis('arg').reset_index(name='ARGO')
arg1counts = df['ARG1'].value_counts().rename_axis('arg').reset_index(name='ARG1')

# Outer join: an argument seen in only one role gets NaN for the other role,
# which is then treated as zero occurrences.
argcounts = arg0counts.merge(arg1counts, on='arg', how='outer')

argcounts['ARGO'] = argcounts['ARGO'].fillna(0)
argcounts['ARG1'] = argcounts['ARG1'].fillna(0)

# Every row comes from at least one non-empty value_counts, so total > 0.
argcounts['total'] = argcounts['ARGO'] + argcounts['ARG1']
argcounts['ARGO_share'] = argcounts['ARGO'] / argcounts['total']
argcounts = argcounts.sort_values(by=['ARGO_share'], ascending=False)

n, bins, patches = plt.hist(x=argcounts.ARGO_share, bins='auto', color='#0504aa',
                            alpha=0.7, rwidth=0.85)
plt.grid(axis='y', alpha=0.75)
plt.xlabel('Share of times an entity appears as an agent (as opposed to a patient)')
plt.ylabel('Frequency')
plt.savefig('../figures/Figure_D_1.pdf')
